Stable Diffusion code
September 15, 2025
16:36
Before reading the code, be clear about two key correspondences:
Summary: the key is understanding the cross-attention part of UNET_AttentionBlock, i.e. treating the image as a sequence. In NLP, the sequence length is the number of tokens in a passage; here the sequence length is the image's H*W, and the number of channels plays the role of the feature dimension, i.e. the length of the token embedding.
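A minimal sketch of this correspondence (the tensors here are dummies for illustration, not part of the model code):

import torch

feature_map = torch.randn(2, 320, 32, 32)        # (B, C, H, W): an image-like feature map
tokens = feature_map.flatten(2).transpose(1, 2)   # (B, H*W, C) = (2, 1024, 320)
# Compare with an NLP batch of shape (B, seq_len, d_model): here seq_len = H*W and d_model = C.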
import torch
from torch import nn
from torch.nn import functional as F

from attention import SelfAttention, CrossAttention
class TimeEmbedding(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)

    def forward(self, x):
        # x: (1, 320) -- initial embedding of the timestep t (usually a sinusoidal encoding)
        # (1, 320) -> (1, 1280)
        x = self.linear_1(x)
        # (1, 1280) -> (1, 1280)
        x = F.silu(x)
        # (1, 1280) -> (1, 1280)
        x = self.linear_2(x)
        return x  # output: (1, 1280)
The residual block is the UNET's basic feature-transformation unit. It is responsible for: 1) adjusting the number of feature channels; 2) injecting the time embedding into the spatial features; 3) avoiding vanishing gradients via the residual connection.
class UNET_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_time=1280):  # n_time = TimeEmbedding output dim
        super().__init__()
        # Branch 1: feature-map processing
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)  # group norm (32 groups, suited to small batches)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)  # 3x3 conv (keeps resolution)
        # Branch 2: time-embedding processing
        self.linear_time = nn.Linear(n_time, out_channels)  # project the time embedding to the feature channel count
        # Branch 3: post-fusion processing
        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # Residual connection: if the channel counts differ, adjust with a 1x1 conv
        self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, feature, time):  # feature: (B, in_ch, H, W); time: (1, 1280)
        residue = feature  # keep the residual
        # 1. Process the feature map
        feature = self.groupnorm_feature(feature)  # normalize
        feature = F.silu(feature)                  # activate
        feature = self.conv_feature(feature)       # channels: in -> out (resolution unchanged)
        # 2. Process the time embedding and add it to the feature map via broadcasting
        time = F.silu(time)            # activate the time embedding
        time = self.linear_time(time)  # 1280 -> out_ch, shape (1, out_ch)
        # unsqueeze twice: (1, out_ch) -> (1, out_ch, 1, 1), which broadcasts against a feature map of any
        # resolution (B, out_ch, H, W) -- this is how time information is injected into the spatial features.
        merged = feature + time.unsqueeze(-1).unsqueeze(-1)
        # 3. Further processing after fusion
        merged = self.groupnorm_merged(merged)
        merged = F.silu(merged)
        merged = self.conv_merged(merged)
        # 4. Residual connection: input residual + processed features
        return merged + self.residual_layer(residue)
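A quick illustration of the broadcasting step above (dummy tensors; out_channels assumed to be 320):

time_vec = torch.randn(1, 320)                            # output of linear_time, shape (1, out_ch)
feature = torch.randn(2, 320, 32, 32)                     # (B, out_ch, H, W)
merged = feature + time_vec.unsqueeze(-1).unsqueeze(-1)   # (1, 320) -> (1, 320, 1, 1), broadcasts to (2, 320, 32, 32)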
The attention block is where the diffusion model captures spatial dependencies and conditional dependencies. It contains self-attention and cross-attention (a hypothetical sketch of a cross-attention layer follows right after this list):
Self-attention: captures spatial relations between pixels inside the latent (e.g. the positional relation between "the cat's ears" and "the cat's face");
Cross-attention: aligns the latent features with the conditioning information, such as a text embedding (e.g. the text embedding of "red rose" guides the generation of red petals).
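SelfAttention and CrossAttention are imported from a separate attention module that is not shown in this note. The sketch below is a hypothetical, simplified cross-attention layer with a matching interface, only to make the data flow concrete; it is not the actual implementation used here.

class CrossAttentionSketch(nn.Module):
    # Hypothetical simplified cross-attention: queries come from image tokens, keys/values from the text context.
    def __init__(self, n_heads, d_embed, d_context):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads
        self.q_proj = nn.Linear(d_embed, d_embed)
        self.k_proj = nn.Linear(d_context, d_embed)
        self.v_proj = nn.Linear(d_context, d_embed)
        self.out_proj = nn.Linear(d_embed, d_embed)

    def forward(self, x, context):  # x: (B, H*W, C); context: (B, SeqLen, d_context)
        b, hw, _ = x.shape
        q = self.q_proj(x).view(b, hw, self.n_heads, self.d_head).transpose(1, 2)        # (B, heads, H*W, d_head)
        k = self.k_proj(context).view(b, -1, self.n_heads, self.d_head).transpose(1, 2)  # (B, heads, SeqLen, d_head)
        v = self.v_proj(context).view(b, -1, self.n_heads, self.d_head).transpose(1, 2)
        scores = q @ k.transpose(-1, -2) / self.d_head ** 0.5  # (B, heads, H*W, SeqLen): each image token scores every text token
        out = scores.softmax(dim=-1) @ v                       # (B, heads, H*W, d_head)
        out = out.transpose(1, 2).reshape(b, hw, -1)           # back to (B, H*W, C)
        return self.out_proj(out)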
class UNET_AttentionBlock(nn.Module):
    def __init__(self, n_head: int, n_embd: int, d_context=768):  # d_context = text embedding dim (768 for CLIP)
        super().__init__()
        self.channels = n_head * n_embd  # total attention dim (multi-head: head count x per-head dim)
        # Input processing
        self.groupnorm = nn.GroupNorm(32, self.channels)
        self.conv_input = nn.Conv2d(self.channels, self.channels, kernel_size=1)  # 1x1 conv to adjust features
        # Self-attention branch
        self.layernorm_1 = nn.LayerNorm(self.channels)
        self.attention_1 = SelfAttention(n_head, self.channels)  # self-attention (no external condition)
        # Cross-attention branch
        self.layernorm_2 = nn.LayerNorm(self.channels)
        self.attention_2 = CrossAttention(n_head, self.channels, d_context)  # cross-attention (needs context)
        # FFN branch (GeGLU activation)
        self.layernorm_3 = nn.LayerNorm(self.channels)
        self.linear_geglu_1 = nn.Linear(self.channels, 4 * self.channels * 2)  # output split into two parts (x and gate)
        self.linear_geglu_2 = nn.Linear(4 * self.channels, self.channels)
        # Output processing
        self.conv_output = nn.Conv2d(self.channels, self.channels, kernel_size=1)

    def forward(self, x, context):  # x: (B, C, H, W); context: (B, SeqLen, 768), e.g. a text embedding
        residue_long = x  # long residual (spans the whole attention block)
        # 1. Input normalization and convolution
        x = self.groupnorm(x)
        x = self.conv_input(x)
        # 2. Feature map -> sequence: (B, C, H, W) -> (B, H*W, C) (attention expects a sequence)
        n, c, h, w = x.shape
        x = x.view(n, c, h * w).transpose(-1, -2)  # after transpose: (B, H*W, C); H*W = sequence length, C = feature dim
        # 3. Self-attention (with residual)
        residue_short = x  # short residual (spans only the self-attention)
        x = self.layernorm_1(x)
        x = self.attention_1(x)  # self-attention: (B, H*W, C) -> (B, H*W, C)
        x += residue_short
        # 4. Cross-attention (with residual)
        residue_short = x
        x = self.layernorm_2(x)
        x = self.attention_2(x, context)  # cross-attention: context guides x (e.g. text -> image features)
        x += residue_short
        # 5. FFN (GeGLU activation, with residual)
        residue_short = x
        x = self.layernorm_3(x)
        # GeGLU: split into x and gate, then compute x * gelu(gate) (a gated activation)
        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)  # two halves: (B, H*W, 4C) and (B, H*W, 4C)
        x = x * F.gelu(gate)        # GeGLU core operation
        x = self.linear_geglu_2(x)  # 4C -> C
        x += residue_short
        # 6. Sequence -> feature map: (B, H*W, C) -> (B, C, H, W)
        x = x.transpose(-1, -2).view(n, c, h, w)
        # 7. Final residual connection (initial input + attention-block output)
        return self.conv_output(x) + residue_long
Why the shape conversion: attention layers (as in Transformers) expect inputs of shape (batch, sequence length, feature dim), while UNET features are (batch, channels, height, width); therefore height x width is flattened into the sequence length and the channel count serves as the feature dimension.
Where the CrossAttention context comes from: usually the text embedding produced by a CLIP model. In Stable Diffusion, the text "a cat" is encoded by CLIP as (1, 77, 768), where 77 is the maximum text length and 768 is the feature dimension. The small shape check below shows what attending over that context looks like.
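A minimal shape check of how the H*W latent tokens attend over the 77 text tokens (dummy tensors, single head; the projection layer here is hypothetical):

img_tokens = torch.randn(1, 32 * 32, 320)               # (B, H*W, C): latent pixels viewed as a token sequence
text_ctx = torch.randn(1, 77, 768)                       # (B, 77, 768): CLIP text embedding
to_k = nn.Linear(768, 320)                               # hypothetical key projection into the attention dim
scores = img_tokens @ to_k(text_ctx).transpose(-1, -2)   # (1, 1024, 77): one score per (image token, text token) pair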
The decoder needs upsampling to restore the latent resolution (from H/64 back to H/8). This module uses nearest-neighbor interpolation for cheap upsampling, followed by a 3x3 convolution that refines the features and counteracts interpolation blur. The encoder has no separate Downsample module, because its resolution reduction is done directly by stride-2 Conv2d layers.
class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):  # x: (B, C, H, W)
        # Nearest-neighbor interpolation: H and W double (e.g. H/32 -> H/16)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)  # 3x3 conv: keeps the channel count, refines the interpolated features
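A quick shape check contrasting the encoder's stride-2 downsampling with this Upsample module (dummy tensors):

x = torch.randn(1, 320, 32, 32)
down = nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)(x)  # (1, 320, 16, 16): stride-2 conv halves the resolution
up = Upsample(320)(down)                                           # (1, 320, 32, 32): interpolation doubles it back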
class SwitchSequential(nn.Sequential):
    # A Sequential whose forward routes the extra arguments: attention blocks receive the context,
    # residual blocks receive the time embedding, and everything else (plain convs) receives only x.
    def forward(self, x, context, time):
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                x = layer(x, context)
            elif isinstance(layer, UNET_ResidualBlock):
                x = layer(x, time)
            else:
                x = layer(x)
        return x
UNET is the core network of the diffusion model. It uses a symmetric encoder - bottleneck - decoder structure: downsampling (encoder) extracts global features, the bottleneck processes the most abstract features, and upsampling (decoder) restores detail, with skip connections supplying the detail lost during downsampling, all in service of noise prediction.
Structural logic (matching the code)
UNET's input is the latent (B, 4, H/8, W/8) and its output is (B, 320, H/8, W/8), which UNET_OutputLayer later converts to a 4-channel noise prediction. The small concat sketch below shows why the decoder blocks expect the doubled channel counts seen further down.
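A minimal illustration of the skip-connection channel arithmetic referenced above (dummy tensors; 8x8 stands in for Height/64 x Width/64):

bottleneck_out = torch.randn(1, 1280, 8, 8)          # current decoder input
skip = torch.randn(1, 1280, 8, 8)                     # matching encoder output popped from skip_connections
merged = torch.cat((bottleneck_out, skip), dim=1)     # (1, 2560, 8, 8): why the first decoder block is UNET_ResidualBlock(2560, 1280)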
class UNET(nn.Module):
    def __init__(self):
        super().__init__()

        # 1. Encoders: downsample + increase channels, extracting global features.
        # Each stage processes features with ResidualBlock + AttentionBlock, then a stride-2 conv halves the
        # resolution; every stage's output is also collected as a skip connection (skip_connections) for the decoder.
        self.encoders = nn.ModuleList([
            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
            # (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
            # (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
        ])
        # 2. Bottleneck: lowest resolution, processes the most abstract features.
        # The bottleneck is the end of the encoder and the start of the decoder: lowest resolution (H/64),
        # highest channel count (1280). It applies only ResidualBlock + AttentionBlock + ResidualBlock to the
        # most abstract global features (the cheapest place to capture global dependencies).
        self.bottleneck = SwitchSequential(
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280),
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_AttentionBlock(8, 160),
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280),
        )
        # 3. Decoders: upsample + reduce channels, recovering detail.
        # The decoder mirrors the encoder. Each block first concatenates the current features with the matching
        # encoder output from skip_connections (torch.cat along the channel dim, so the input channel counts add up),
        # restoring detail lost during downsampling, then processes them with ResidualBlock + AttentionBlock,
        # and finally an Upsample doubles the resolution.
        self.decoders = nn.ModuleList([
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
            # (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),
            # (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
            # (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
            # (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
            # (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
        ])
    def forward(self, x, context, time):
        # x: (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim)
        # time: (1, 1280)
        skip_connections = []
        for layers in self.encoders:
            x = layers(x, context, time)
            skip_connections.append(x)

        x = self.bottleneck(x, context, time)

        for layers in self.decoders:
            # Since we always concatenate with the encoder's skip connection, the number of features
            # increases before being sent to the decoder's layer.
            x = torch.cat((x, skip_connections.pop()), dim=1)
            x = layers(x, context, time)

        return x
Converts the 320-channel feature map output by the UNET decoder back to 4 channels, matching the input latent (the predicted noise must have the same channel count as the latent).
class UNET_OutputLayer(nn.Module):
    def __init__(self, in_channels, out_channels):  # in=320, out=4
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):  # x: (B, 320, H/8, W/8)
        x = self.groupnorm(x)  # normalize
        x = F.silu(x)          # activate
        x = self.conv(x)       # 320 -> 4 channels (resolution unchanged)
        return x  # output: (B, 4, H/8, W/8) -- the predicted noise
The Diffusion class wraps the full pipeline -- time embedding -> UNET feature processing -> noise output -- and is the single entry point for both training and inference.
class Diffusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        # latent: (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim)
        # time: (1, 320)
        # (1, 320) -> (1, 1280)
        time = self.time_embedding(time)
        # (Batch, 4, Height / 8, Width / 8) -> (Batch, 320, Height / 8, Width / 8)
        output = self.unet(latent, context, time)
        # (Batch, 320, Height / 8, Width / 8) -> (Batch, 4, Height / 8, Width / 8)
        output = self.final(output)
        # (Batch, 4, Height / 8, Width / 8)
        return output
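A minimal smoke test of the whole module, assuming the attention module providing SelfAttention and CrossAttention is importable (shapes correspond to a 512x512 image, whose latent is 64x64):

model = Diffusion()
latent = torch.randn(1, 4, 64, 64)     # noisy latent: (B, 4, H/8, W/8)
context = torch.randn(1, 77, 768)      # CLIP text embedding: (B, 77, 768)
time = torch.randn(1, 320)             # timestep embedding fed to TimeEmbedding: (1, 320)
noise_pred = model(latent, context, time)
print(noise_pred.shape)                # torch.Size([1, 4, 64, 64])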